import pandas as pd
import numpy as np
import re
import collections
import itertools
import operator
from itertools import chain

import emoji
import spacy
import nltk
from nltk.corpus import stopwords
from nltk import bigrams
from nltk.probability import FreqDist
import en_core_web_sm

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import StandardScaler
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn import metrics

import gensim
from gensim.corpora import Dictionary

import csv
import gensim
from gensim.corpora import Dictionary
from gensim.models import LdaModel
from gensim.models import Phrases
from gensim.models.phrases import Phraser
from gensim.models import CoherenceModel

from datetime import datetime

import pprint
import matplotlib.pyplot as plt
import seaborn as sns
import glob


# Location of the MALLET launcher handed to the LdaMallet wrapper further down
mallet_path = '.../mallet-2.0.8/bin/mallet'


# Both NLTK corpora read below have to be available locally; download the
# stop-word list as well as the English vocabulary before using them.
nltk.download('stopwords')
nltk.download('words')

STOPWORDS = set(stopwords.words('english'))  # stop words removed by cleaner()
NLP = en_core_web_sm.load()  # spaCy pipeline used by cleaner() for POS tags and lemmas
words = set(nltk.corpus.words.words())  # vocabulary used by cleaner() to filter tokens


# Load 116th congress tweets dataset
path = '/Users/...' # replace the path with your own dir

# Read only the sheet that is actually used ('Sheet1') from every workbook and
# concatenate once at the end: growing a DataFrame inside the loop copies all
# previously loaded rows again on every iteration, and pd.read_excel closes
# the workbook for us.
frames = []
for item in glob.glob(path+ "116_congress/*.xlsx"):
    print(item)
    df = pd.read_excel(item, sheet_name='Sheet1')
    frames.append(df)

# Fall back to an empty frame when no workbook matched the pattern
con_df = pd.concat(frames) if frames else pd.DataFrame()

print(con_df.shape)


# Party affiliation table for the 116th Congress; every column is read as text
con_inf = pd.read_csv('party_twitter_complete.csv', dtype=str)


# Attach the party information to the tweets by matching the tweet author's
# screen name with the handle in the party table. The right join keeps every
# row of the party table, whether or not a tweet matched it.
df = pd.merge(
    con_df,
    con_inf,
    how="right",
    left_on="user_screen_name",
    right_on="twitter_handle",
)

df.isnull().sum(axis=0)  # missing values per column (notebook inspection, result unused)
# Keep only the rows that carry tweet text
has_text = df['text'].notnull()
df = df.loc[has_text].reset_index(drop=True)
df.head(3)


# Derive calendar columns from the tweet timestamp, plus lower-cased hashtags
# and the list of accounts mentioned in each tweet
tweet_dates = [
    datetime.strptime(stamp, '%a %b %d %H:%M:%S %z %Y').date()
    for stamp in df['created_at'].tolist()
]
df['date'] = tweet_dates
df['year'] = [d.year for d in tweet_dates]
df['month'] = [d.month for d in tweet_dates]
df['day'] = [d.day for d in tweet_dates]
df['year_month'] = [d.strftime('%Y-%m') for d in tweet_dates]
# Non-string hashtag entries (e.g. missing values) are passed through unchanged
df['hashtag_lower'] = df['hashtags'].map(
    lambda tag: tag.lower() if isinstance(tag, str) else tag
)
df['at'] = df['text'].apply(lambda body: re.findall(r'@(\w+)', body))


# Keep the tweets posted from 2020-01 through 2021-07 (the months listed below)
study_months = [
    '2020-01', '2020-02', '2020-03', '2020-04', '2020-05', '2020-06',
    '2020-07', '2020-08', '2020-09', '2020-10', '2020-11', '2020-12',
    '2021-01', '2021-02', '2021-03', '2021-04', '2021-05', '2021-06',
    '2021-07',
]
in_study_window = df['year_month'].isin(study_months)
paper_df = df[in_study_window].reset_index(drop=True)
paper_df.shape


# define function to clean tweets
def cleaner(tweet):
    """Convert a raw tweet into a list of content-word lemmas.

    The text is stripped of escaped ampersands, @mentions, links, emojis and
    punctuation, filtered against the `words` vocabulary and `STOPWORDS`, and
    finally lemmatized with the `NLP` pipeline.

    Args:
        tweet (str): raw tweet text.

    Returns:
        list[str]: lemmas of the remaining tokens whose part of speech is
        NOUN, ADJ, VERB, ADV or NUM.
    """
    content_pos = ('NOUN', 'ADJ', 'VERB', 'ADV', 'NUM')

    # Drop the HTML-escaped ampersand ("&amp;") only; deleting every "amp"
    # substring would also mangle words such as "campaign" or "example".
    tweet = tweet.replace("&amp;", " ")
    # Remove @mentions; \w also matches underscores, so handles such as
    # @some_user are removed as a whole instead of leaving "_user" behind.
    tweet = re.sub(r"@\w+", "", tweet)
    # Remove links (and anything still starting with "@")
    tweet = re.sub(r"(?:@|http?://|https?://|www)\S+", "", tweet)
    tweet = " ".join(tweet.split())
    # Remove emojis. emoji >= 2.0 provides replace_emoji and no longer has
    # UNICODE_EMOJI; the per-character filter is kept for older releases.
    if hasattr(emoji, "replace_emoji"):
        tweet = emoji.replace_emoji(tweet, replace="")
    else:
        tweet = ''.join(c for c in tweet if c not in emoji.UNICODE_EMOJI)
    # Remove hashtag sign but keep the text
    tweet = tweet.replace("#", "").replace("_", " ")
    # Keep a token when its lower-cased form is in the vocabulary, or when it
    # is not purely alphabetic (numbers, tokens with attached punctuation)
    tweet = " ".join(
        w.lower() for w in tweet.split()
        if w.lower() in words or not w.isalpha()
    )
    tweet = re.sub(r'[^\w\s]', '', tweet)
    # Tokenize, then drop one-character tokens and stop words
    tokens = [
        token for token in tweet.split()
        if len(token) > 1 and token not in STOPWORDS
    ]
    # Lemmatize and keep content words only
    doc = NLP(" ".join(tokens))
    return [token.lemma_ for token in doc if token.pos_ in content_pos]


# Tokenize and lemmatize every tweet of the study window
paper_df['tweet'] = paper_df['text'].map(cleaner)
paper_df.head(1)


# ## Exploratory analysis
# - figure about the tweet frequency among Rep and Dem across time
# - find all covid-related tweets by using keywords and hashtags
# - figure about the covid-19 related tweets
#     - the covid-19 related tweet frequency among Rep and Dem across time
#     - most frequent unigram/bigram among Rep and Dem

# Tweets per (year, month), one bar per party
monthly_by_party = paper_df.groupby(['year', 'month'])['party_id'].value_counts().unstack()
monthly_by_party.plot.barh(color={"R": "red", "D": "blue", 'I': 'gray'})


# Keywords that mark a tweet as covid-related (matched as substrings of the
# lower-cased tweet text)
key_word = ['covid', 'ncov', 'pandemic', 'virus', 'corona', 'covid19', 'covid-19', 'vax', 'vaccine',
           'vaccination', 'vaccinated', 'immunization', 'vaccinate', 'vaccinated']

# Index labels of the tweets that contain at least one keyword
tw_lst = [
    ix
    for ix, text in paper_df['text'].items()
    if any(item in text.lower() for item in key_word)
]

print(len(tw_lst))
df2021_cov = paper_df.loc[tw_lst]
df2021_cov.shape

# Covid-related tweets per (year, month), one bar per party
covid_monthly = df2021_cov.groupby(['year', 'month'])['party_id'].value_counts().unstack()
covid_monthly.plot.barh(color={"R": "red", "D": "blue", 'I': 'gray'})


# Split the covid tweets by party and pool their cleaned tokens
dem_cov = df2021_cov[df2021_cov.party_id == "D"]
rep_cov = df2021_cov[df2021_cov.party_id == "R"]
democrat_tweets = list(chain.from_iterable(dem_cov['tweet'].tolist()))
republican_tweets = list(chain.from_iterable(rep_cov['tweet'].tolist()))
fdist_democrat = FreqDist(democrat_tweets)
fdist_republican = FreqDist(republican_tweets)

# Frequency plot of the 30 most common tokens, one figure per party
for party_fdist, party_title in ((fdist_democrat, "Democrat Tweets"),
                                 (fdist_republican, "Republican Tweets")):
    plt.subplots(figsize=(7,3))
    party_fdist.plot(30, title=party_title)


# Bigrams in Democratic covid tweets: built per tweet, flattened, then counted
D_terms_bigram = [list(bigrams(tweet)) for tweet in dem_cov.tweet.tolist()]
D_all_bigrams = list(chain.from_iterable(D_terms_bigram))
D_bigram_counts = collections.Counter(D_all_bigrams)

# Same for Republican covid tweets
R_terms_bigram = [list(bigrams(tweet)) for tweet in rep_cov.tweet.tolist()]
R_all_bigrams = list(chain.from_iterable(R_terms_bigram))
R_bigram_counts = collections.Counter(R_all_bigrams)


# Frequency plot of the 50 most common bigrams, one figure per party
bi_fdist_democrat = FreqDist(D_all_bigrams)
bi_fdist_republican = FreqDist(R_all_bigrams)

for party_fdist, party_title in ((bi_fdist_democrat, "Democrat Tweets"),
                                 (bi_fdist_republican, "Republican Tweets")):
    plt.subplots(figsize=(9,3))
    party_fdist.plot(50, title=party_title)


# study 1 figures
# The 45 most frequent bigrams of each party, as (bigram, count) pairs
a1 = dict(bi_fdist_republican)
b1 = sorted(a1.items(),key=operator.itemgetter(1),reverse=True)[:45]

a2 = dict(bi_fdist_democrat)
b2 = sorted(a2.items(),key=operator.itemgetter(1),reverse=True)[:45]

c1 = pd.DataFrame(b1, columns = ['phrase', 'cnt_r'])
c2 = pd.DataFrame(b2, columns = ['phrase', 'cnt_d'])

# A bigram that is in only one party's top list gets a count of 0 for the other
c = pd.merge(c1, c2, on='phrase', how='outer')
c = c.fillna(0)
c['perc_r'] = c['cnt_r'] / c['cnt_r'].sum()
c['perc_d'] = c['cnt_d'] / c['cnt_d'].sum()
c['diff_perc'] = c['perc_r'] - c['perc_d']
c['diff_cnt'] = c['cnt_r'] - c['cnt_d']

# The 15 bigrams leaning most Republican and the 15 leaning most Democratic.
# NOTE(review): .iloc[1:] drops the row with the largest Republican surplus —
# confirm that this exclusion is intended.
ranked_by_diff = c.sort_values(by='diff_cnt', ascending=False)
f1 = pd.concat([ranked_by_diff.head(15), ranked_by_diff.tail(15)]).reset_index(drop=True).iloc[1:].copy()
print(f1.shape)

f1['color'] = ['firebrick' if item > 0 else 'mediumblue' for item in f1.diff_cnt]


# plot figure 1
fig, ax = plt.subplots(figsize=(10, 7))
f2 = dict(zip(f1.phrase, f1.diff_cnt))
f3 = dict(sorted(f2.items(), key=operator.itemgetter(1), reverse=True))

keys = list(f3.keys())
# values in the same order as keys
vals = [f3[k] for k in keys]
sns.barplot(y=[str(k) for k in keys], x=vals, palette=f1.color.tolist(), ax=ax).\
    set(title='Absolute difference in bigrams used by political party')

# Ticks and labels go on `ax`, the axes the bars were drawn on
ax.set_xticks(range(-4500, 1000, 500))
ax.set_xticklabels(['', '4000 more \n Democratic', '','3000 more \n Democratic', '','2000 more \n Democratic',
                    '','1000 more \n Democratic', '','same', '500 more \n Republican'])

# LDA
# Phrase detection: first on the cleaned tweets, then a second pass on the
# output of the first
docs = df2021_cov['tweet'].tolist()
bigram = Phrases(docs, min_count=2)
trigram = Phrases(bigram[docs])

# Wrap both detectors in Phraser objects for applying them to documents
bigram_model = Phraser(bigram)
trigram_model = Phraser(trigram)

# Rewrite every tweet with the detected phrases merged into single tokens
phrased_docs = []
for doc in docs:
    phrased_docs.append(trigram_model[bigram_model[doc]])
docs = phrased_docs

# Vocabulary of the phrased tweets
id2word = Dictionary(docs)
# filter out words that occur in more than 50% of the documents.
id2word.filter_extremes(no_above=0.5)

# Bag-of-words representation of every tweet
corpus = list(map(id2word.doc2bow, docs))


# Fit one Mallet LDA model per candidate topic count (6, 8, ..., 40) and score
# each model with c_v coherence
x = range(6, 41, 2)
model_list = []
coherence_values = []
for num_topics in x:
    model = gensim.models.wrappers.LdaMallet(
        mallet_path,
        corpus=corpus,
        id2word=id2word,
        num_topics=num_topics,
    )
    coherencemodel = CoherenceModel(model=model, texts=docs, dictionary=id2word, coherence='c_v')
    model_list.append(model)
    coherence_values.append(coherencemodel.get_coherence())



# Select the model with the highest coherence score
best_coherence = np.max(coherence_values)
optimal_idx = coherence_values.index(best_coherence)
optimal_model = model_list[optimal_idx]
optimal_num_topics = x[optimal_idx]

print(optimal_idx, optimal_num_topics)
pprint.pprint(optimal_model.print_topics(num_words=20))



# present the topics for each tweet
def format_topics_sentences(ldamodel, corpus, texts):
    """Build a table with the dominant topic of every document.

    Args:
        ldamodel: topic model; ``ldamodel[corpus]`` must yield, per document,
            a list of ``(topic_id, weight)`` pairs, and
            ``ldamodel.show_topic(topic_id)`` a list of ``(word, weight)``
            pairs.
        corpus: the documents in the representation ``ldamodel`` expects.
        texts: the original documents, in the same order as ``corpus``.

    Returns:
        pandas.DataFrame: one row per document with the columns
        'Dominant_Topic', 'Perc_Contribution' (weight rounded to 4 decimals)
        and 'Topic_Keywords' (comma-separated topic words), followed by a
        column named 0 that holds ``texts``. A document for which the model
        returns no topic gets missing values in the first three columns, so
        the rows always stay aligned with ``texts``.
    """
    columns = ['Dominant_Topic', 'Perc_Contribution', 'Topic_Keywords']
    keywords_by_topic = {}  # show_topic is queried once per topic, not per document
    records = []

    for topic_weights in ldamodel[corpus]:
        if not topic_weights:
            # Placeholder row keeps later documents on their own position
            records.append((float("nan"), float("nan"), None))
            continue
        # Dominant topic = highest weight (first one wins on ties)
        topic_num, prop_topic = max(topic_weights, key=lambda pair: pair[1])
        if topic_num not in keywords_by_topic:
            keywords_by_topic[topic_num] = ", ".join(
                word for word, _ in ldamodel.show_topic(topic_num)
            )
        records.append((int(topic_num), round(prop_topic, 4), keywords_by_topic[topic_num]))

    # Build the frame in one go; DataFrame.append is gone in pandas 2.0 and
    # re-copied the whole frame for every document.
    sent_topics_df = pd.DataFrame(records, columns=columns)

    # Add original text to the end of the output
    contents = pd.Series(texts)
    sent_topics_df = pd.concat([sent_topics_df, contents], axis=1)
    return sent_topics_df

df_topic_sents_keywords = format_topics_sentences(ldamodel=optimal_model, corpus=corpus, texts=docs)
df_dominant_topic = df_topic_sents_keywords.reset_index()
df_dominant_topic.columns = ['Document_No', 'Dominant_Topic', 'Topic_Perc_Contrib', 'Keywords', 'Text']
df_dominant_topic.head(3)


# merge raw dataset with party information with tweets topics to compare partisan topics
# df2021_cov still carries the row labels it had in paper_df, whereas
# df_dominant_topic is numbered 0..n-1 in document order. A column-wise concat
# aligns on the index, so the tweets are renumbered first; otherwise each
# tweet would be paired with the topic of a different document.
tweet_meta = df2021_cov[['user_screen_name', 'text',
                         'party_id', 'year_month']].reset_index(drop=True)
dfm = pd.concat([tweet_meta, df_dominant_topic], axis=1)
dfm.head(3)


def _topic_share(frame, party, prefix):
    """Count the dominant topics of one party's tweets and add their share.

    Returns a frame sorted by descending share with the columns
    'Dominant_Topic', '<prefix>cnt' and '<prefix>perc'.
    """
    counts = frame[frame.party_id == party].groupby('Dominant_Topic').size().reset_index(name='cnt')
    counts['perc'] = counts['cnt'] / counts['cnt'].sum()
    counts = counts.sort_values(by='perc', ascending=False)
    return counts.rename(columns={'perc': prefix + 'perc', 'cnt': prefix + 'cnt'})


td = _topic_share(dfm, 'D', 'd')
tr = _topic_share(dfm, 'R', 'r')

# Topics that are dominant in at least one tweet of both parties
trd = td.merge(tr, on='Dominant_Topic')
trd['topic'] = [str(int(item)) for item in trd.Dominant_Topic.tolist()]
topic_word = df_topic_sents_keywords[['Dominant_Topic', 'Topic_Keywords']].drop_duplicates()
trd1 = trd.merge(topic_word, on='Dominant_Topic', how='left')


# Positive difference: the topic takes a larger share of Republican tweets
trd1['diff_perc'] = trd1['rperc'] - trd1['dperc']
trd1['color'] = ["firebrick" if item > 0 else "mediumblue" for item in trd1.diff_perc]
trd2 = trd1.sort_values(by='diff_perc', ascending=False).reset_index(drop=True)


# Keyword lists of the topics, for labelling them by hand
trd2.Topic_Keywords.tolist()

# Bar chart of the per-topic share difference between the parties,
# largest Republican surplus first
fig, ax = plt.subplots(figsize=(8, 5))
p2 = dict(zip(trd2.topic, trd2.diff_perc))
ranked = sorted(p2.items(), key=operator.itemgetter(1), reverse=True)
p3 = dict(ranked)

keys = [topic for topic, _ in ranked]
vals = [diff for _, diff in ranked]
topic_ax = sns.barplot(y=[str(k) for k in keys], x=vals, palette = trd2.color)
topic_ax.set(title='Percentage difference in topics used by political party')


# try k-means model on partisan tweets
def _cluster_party_tweets(party):
    """Cluster one party's covid tweets with k-means on TF-IDF features.

    k is picked from 2, 4, ..., 20 as the candidate with the highest
    silhouette score.

    Args:
        party: value of `party_id` to select from `df2021_cov`.

    Returns:
        tuple: (tweets, summary) where `tweets` is the party's slice of
        `df2021_cov` (index reset) with a 'cluster' column, and `summary` has
        one row per cluster with its size ('cnt'), share ('perc'), the party
        ('party') and the ten highest-weighted centroid terms ('topic').
    """
    party_df = df2021_cov[df2021_cov.party_id == party].reset_index(drop=True)
    docs = [" ".join(doc) for doc in party_df.tweet.tolist()]
    tfidf_vectorizer = TfidfVectorizer(
        max_df=0.8,
        max_features=10000,
        ngram_range=(1,3)
    )
    X = tfidf_vectorizer.fit_transform(docs)

    silhouette_list = []
    kmeans_model = []
    k_candidates = list(range(2, 21, 2))
    for k in k_candidates:
        model = KMeans(
            n_clusters=k,
            random_state=2020,
            init="k-means++",
            max_iter=100,
            n_init=8
        )
        model.fit(X)
        silhouette_list.append(metrics.silhouette_score(X, model.labels_))
        kmeans_model.append(model)

    optimal_idx = silhouette_list.index(np.max(silhouette_list))
    optimal_model = kmeans_model[optimal_idx]
    optimal_k = k_candidates[optimal_idx]

    party_df['cluster'] = optimal_model.labels_

    summary = party_df.groupby('cluster').size().reset_index(name = 'cnt')
    summary['perc'] = summary['cnt'] / summary['cnt'].sum()
    summary['party'] = party

    # For every centroid, feature indices ordered by descending weight
    order_centroids = optimal_model.cluster_centers_.argsort()[:, ::-1]
    # Newer scikit-learn releases expose the vocabulary through
    # get_feature_names_out; older ones only have get_feature_names.
    if hasattr(tfidf_vectorizer, 'get_feature_names_out'):
        terms = tfidf_vectorizer.get_feature_names_out()
    else:
        terms = tfidf_vectorizer.get_feature_names()

    # Ten top terms per cluster; multi-word n-grams are joined with "_"
    res_c = {
        i: ["_".join(terms[ind].split()) for ind in order_centroids[i, :10]]
        for i in range(optimal_k)
    }
    summary['topic'] = [res_c[item] for item in summary.cluster]
    return party_df, summary


# dems
df_cld, df_cld1 = _cluster_party_tweets("D")

#rep
df_clr, df_clr1 = _cluster_party_tweets("R")


# based on the keywords summarize the topics
cluster_sep = pd.concat([df_clr1, df_cld1]).reset_index(drop=True)

# Load the hand-summarized cluster topics and plot the share difference per
# topic, largest value first
cluster_sep2 = pd.read_csv('.../cluster_sep2.csv')
fig, ax = plt.subplots(figsize=(10, 7))
p2 = dict(zip(cluster_sep2.topic_s, cluster_sep2.diff_perc))
ranked = sorted(p2.items(), key=operator.itemgetter(1), reverse=True)
p3 = dict(ranked)

keys = [topic for topic, _ in ranked]
vals = [diff for _, diff in ranked]
# NOTE(review): the palette follows the row order of the csv while the bars
# are sorted by diff_perc — confirm the file is already sorted that way.
topic_ax = sns.barplot(y=[str(k) for k in keys], x=vals, palette = cluster_sep2.color)
topic_ax.set(title='Percentage difference in topics used by political party')

# NOTE(review): the labels are assigned without fixing the tick positions
# first — confirm that the nine labels line up with the ticks matplotlib picks.
ax.set_xticklabels(['25%','20% more \n Democratic', '15% more \n Democratic', '10% more \n Democratic',
                    '5% more \n Democratic', 'same', '5% more \n Republican', '10% more \n Republican', '15% more \n Republican'])



# ## Additional analysis on vaccination related tweets only
# - compare Dems' and Reps' tweets
vax_word = ['vax', 'vaccine', 'vaccination', 'vaccinated', 'immunization', 'vaccinate', 'vaccinated']

# Index labels of the tweets whose lower-cased text contains a vaccination
# keyword. The filter must use vax_word: the general covid keyword list would
# select every covid-related tweet instead of the vaccination subset.
tw_lst = [
    ix
    for ix, text in paper_df['text'].items()
    if any(item in text.lower() for item in vax_word)
]

# Independent copy, so that columns can be added to df_vax later on
df_vax = paper_df.loc[tw_lst].copy()
df_vax.shape


# compare dems and reps about vaccination
df_vax.groupby(['year', 'month'])['party_id'].value_counts().\
    unstack().plot.barh(color={"R": "red", "D": "blue", 'I': 'gray'})


# Retweets/quotes of the CDC account, compared between the parties

df_vax['year_month_day'] = df_vax['date'].map(lambda d: d.strftime('%Y-%m-%d'))
cdc = df_vax[df_vax['retweet_or_quote_screen_name'] == 'CDCgov']
n_cdc = cdc.shape[0]
n_cdc_rep = cdc[cdc.party_id=='R'].shape[0]
n_cdc_dem = cdc[cdc.party_id=='D'].shape[0]
# Republican share, Democratic share, total number of CDC retweets/quotes
print(n_cdc_rep / n_cdc, n_cdc_dem / n_cdc, n_cdc)
cdc1 = cdc.groupby(['year_month', 'party_id'])['id'].count().reset_index(name='cnt')

# Monthly number of CDC retweets/quotes, one line per party
fig, ax = plt.subplots(figsize=(10, 8))
sns.lineplot(data=cdc1, x='year_month', y='cnt', hue='party_id')

cdc.groupby(['year_month', 'party_id'])['id'].count().reset_index(name='cnt')


#################################### Bert model ################
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# use trained BERT model from huggingface to identify covid-related topics
# The tweets are classified in chunks of 50
sample_list = df2021_cov['text'].tolist()
chunk_size = 50
chunked_list = [
    sample_list[i:i+chunk_size]
    for i in range(0, len(sample_list), chunk_size)
]
print(len(chunked_list))


tokenizer = AutoTokenizer.from_pretrained("inovex/multi2convai-corona-en-bert")
model = AutoModelForSequenceClassification.from_pretrained("inovex/multi2convai-corona-en-bert")

total_output = []
# Pure inference: no_grad keeps the forward passes from recording an autograd
# graph for every chunk.
with torch.no_grad():
    for idx, text in enumerate(chunked_list):
        # truncation guards against inputs longer than the model accepts
        encoded_text = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        cls_logits = model(**encoded_text).logits
        cls_probs = torch.softmax(cls_logits, dim=1).numpy()
        # Most probable label and its probability for every tweet of the chunk
        output = [
            {"label": model.config.id2label[int(item.argmax())], "score": item.max().item()}
            for item in cls_probs
        ]
        total_output.append(output)
        print(idx)
    


# Flatten the per-chunk predictions back to one prediction per tweet
df2021_cov['label_c'] = list(chain.from_iterable(total_output))
df2021_cov['label'] = [pred['label'] for pred in df2021_cov.label_c]

# Keep the tweets whose predicted label starts with 'corona' and count the
# labels per party
is_corona = df2021_cov.label.str.startswith('corona')
df2021_cov1 = df2021_cov[is_corona].reset_index(drop=True)
dr = df2021_cov1[df2021_cov1.party_id == 'R'].label.value_counts().to_dict()
dd = df2021_cov1[df2021_cov1.party_id == 'D'].label.value_counts().to_dict()


rep_label_counts = pd.DataFrame(dr.items(), columns = ['type', 'rcnt'])
dem_label_counts = pd.DataFrame(dd.items(), columns = ['type', 'dcnt'])
df_label = pd.merge(rep_label_counts, dem_label_counts, on='type', how='outer')

# plot to compare partisan covid-related topics
f, ax = plt.subplots(figsize=(9,6))
for count_col, line_color in (('rcnt', 'red'), ('dcnt', 'blue')):
    ax.plot(df_label[count_col], df_label.type, color = line_color)


# Label counts relative to all covid-keyword tweets of the party
n_rep_tweets = df2021_cov[df2021_cov.party_id == 'R'].shape[0]
n_dem_tweets = df2021_cov[df2021_cov.party_id == 'D'].shape[0]
df_label['rperc'] = df_label.rcnt / n_rep_tweets
df_label['dperc'] = df_label.dcnt / n_dem_tweets
df_label['diff_perc'] = df_label['rperc'] - df_label['dperc']
df_label['color'] = ["firebrick" if x > 0 else "mediumblue" for x in df_label.diff_perc]
# Keep the labels whose share differs by more than 0.0002 between the parties
noticeable = df_label.diff_perc.abs() > 0.0002
df_label1 = df_label[noticeable].sort_values(by='diff_perc', ascending=False)
df_label1


# Bar chart of the share difference per predicted label, largest value first
fig, ax = plt.subplots(figsize=(10, 7))
p2 = dict(zip(df_label1.type, df_label1.diff_perc))
ranked = sorted(p2.items(), key=operator.itemgetter(1), reverse=True)
p3 = dict(ranked)

keys = [label for label, _ in ranked]
vals = [diff for _, diff in ranked]
label_ax = sns.barplot(y=[str(k) for k in keys], x=vals, palette = df_label1.color)
label_ax.set(title='Percentage difference in topics used by political party')





#### only look at the tweets which have been identified as covid-related by BERT
df2021_cov1['tweet'] = df2021_cov1['text'].map(cleaner)
# Tweets labelled 'corona.vaccine' form the vaccination subset
is_vaccine = df2021_cov.label == 'corona.vaccine'
df2021_vax = df2021_cov[is_vaccine].reset_index(drop=True)
df2021_vax['tweet'] = df2021_vax['text'].map(cleaner)


# Try LDA
# Phrase detection: first on the cleaned tweets, then a second pass on the
# output of the first
docs = df2021_vax['tweet'].tolist()
bigram = Phrases(docs, min_count=2)
trigram = Phrases(bigram[docs])

# Wrap both detectors in Phraser objects for applying them to documents
bigram_model = Phraser(bigram)
trigram_model = Phraser(trigram)

# Rewrite every tweet with the detected phrases merged into single tokens
phrased_docs = []
for doc in docs:
    phrased_docs.append(trigram_model[bigram_model[doc]])
docs = phrased_docs
# Vocabulary of the phrased tweets
id2word = Dictionary(docs)
# filter out words that occur in more than 50% of the documents.
id2word.filter_extremes(no_above=0.5)

# Bag-of-words representation of every tweet
corpus = list(map(id2word.doc2bow, docs))


# Fit one Mallet LDA model per candidate topic count (6, 8, ..., 30) and score
# each model with c_v coherence
x = range(6, 31, 2)
coherence_values = []
model_list = []
for num_topics in x:
    model = gensim.models.wrappers.LdaMallet(
        mallet_path,
        corpus=corpus,
        id2word=id2word,
        num_topics=num_topics
    )
    model_list.append(model)
    coherencemodel = CoherenceModel(model=model, texts=docs, dictionary=id2word, coherence='c_v')
    coherence_values.append(coherencemodel.get_coherence())


# Coherence score against the number of topics. The line carries its own
# label: handing the bare string to plt.legend would be read as a sequence of
# one-character labels and show only "c".
plt.plot(x, coherence_values, label="coherence_values")
plt.xlabel("Num Topics")
plt.ylabel("Coherence score")
plt.legend(loc='best')
plt.show()

# Select the model with the highest coherence score
optimal_idx = coherence_values.index(np.max(coherence_values))
optimal_model = model_list[optimal_idx]
optimal_num_topics = x[optimal_idx]

print(optimal_idx, optimal_num_topics)
pprint.pprint(optimal_model.print_topics(num_words=10))


df_topic_sents_keywords = format_topics_sentences(ldamodel=optimal_model, corpus=corpus, texts=docs)
df_dominant_topic = df_topic_sents_keywords.reset_index()
df_dominant_topic.columns = ['Document_No', 'Dominant_Topic', 'Topic_Perc_Contrib', 'Keywords', 'Text']
df_dominant_topic.head(3)


# df2021_vax has a fresh 0..n-1 index, so these assignments line up with the
# document order of df_dominant_topic
df_dominant_topic['party_id'] = df2021_vax['party_id']
df_dominant_topic['raw_text'] = df2021_vax['text']


# Number of tweets per (party, topic keywords), independents excluded
vax_topic = df_dominant_topic[df_dominant_topic.party_id != 'I'].\
    groupby(['party_id', 'Keywords'])['Document_No'].count().reset_index(name='cnt')


# Share of each topic within its party
vax_topic['perc'] = vax_topic['cnt'] / vax_topic.groupby('party_id')['cnt'].transform('sum')
vax_topic

# Republican (_x) and Democratic (_y) figures side by side. The outer join
# keeps topics that are dominant for only one party; their missing counts and
# shares are 0, which gives a defined difference instead of a dropped row or
# a NaN.
rep_topics = vax_topic[vax_topic.party_id =='R']
dem_topics = vax_topic[vax_topic.party_id =='D']
vax_topic1 = rep_topics.merge(dem_topics, on='Keywords', how='outer')
share_cols = ['cnt_x', 'perc_x', 'cnt_y', 'perc_y']
vax_topic1[share_cols] = vax_topic1[share_cols].fillna(0)
vax_topic1['diff_perc'] = vax_topic1['perc_x'] - vax_topic1['perc_y']
vax_topic2 = vax_topic1.sort_values(by='diff_perc', ascending=False)

# Hand-written topic labels (keys) for the keyword strings of the fitted
# topics (values)
vax_topic_dict = {'vaccine development safty, effectiveness':
 'vaccine, distribution, covid19, safe_effective, safe, year, good, operation_warp_speed, news, show',
 'federal government response, effort':
 'response, federal, public_health, lead, fight, effort, crisis, treatment, covid19, medical',
 'covid19 live, speak, discuss':
 'covid19, today, great, join, hear, live, discuss, meet, hope, speak',
 'virus stop spread':
 'make, time, virus, give, covid19, stop, trump, public, government, spread',
 'distribute vaccine':
 'people, vaccine, state, million, day, covid19, administration, today, distribute, end',
 "support, relief, bill, rescue plan":
 'support, relief, act, provide, bill, family, legislation, pass, emergency, rescue_plan',
 "health care, community protect":
 'pandemic, health, community, protect, ensure, access, health_care, continue, critical, care',
 'work, save, pandemic':
 'work, pandemic, nation, program, increase, country, economy, save, good, world',
 "covid19 testing":
 'covid19, testing, free, test, plan, expand, crisis, leave, national, key',
 "vaccinate, vaccine, shot":
 'vaccinate, vaccine, covid, part, make_sure, encourage, people, covid19, continue, shoot',
 "vaccination, open, site":
 'today, vaccination, week, county, community, open, learn, local, site, update',
 "vaccine, appointment, call, visit, start":
 'vaccine, covid19, receive, find, visit, call, appointment, information, eligible, start'}

# Reverse mapping keyword string -> label, built once (first label wins if a
# keyword string were listed twice)
label_by_keywords = {}
for topic_label, topic_keywords in vax_topic_dict.items():
    label_by_keywords.setdefault(topic_keywords, topic_label)

vax_topic2['color'] = ["firebrick" if x >0 else "mediumblue" for x in vax_topic2.diff_perc]
# A keyword string without a hand-written label (e.g. from a model run whose
# topics differ from the ones listed above) is shown as-is rather than
# aborting the script.
vax_topic2['topic'] = [label_by_keywords.get(x, x) for x in vax_topic2.Keywords]
vax_topic2

# Bar chart of the share difference per vaccination topic, largest value first
fig, ax = plt.subplots(figsize=(10, 7))
p2 = dict(zip(vax_topic2.topic, vax_topic2.diff_perc))
ranked = sorted(p2.items(), key=operator.itemgetter(1), reverse=True)
p3 = dict(ranked)

keys = [topic for topic, _ in ranked]
vals = [diff for _, diff in ranked]
vax_ax = sns.barplot(y=[str(k) for k in keys], x=vals, palette = vax_topic2.color)
vax_ax.set(title='Percentage difference in topics used by political party')
